# -*- coding: utf-8 -*-
import hashlib
import json
import base64

from Crypto.Cipher import PKCS1_v1_5
from Crypto.Hash import SHA, MD5
from Crypto.Signature import PKCS1_v1_5 as sign_pkcs1_v1_5
from Crypto.PublicKey import RSA
from django.conf import settings

from common.order.db import get_pay
from common.cache import redis_cache
from common.channel.admin_db import get_channel
from common.utils import track_logging
from common.order.model import PAY_STATUS
from common.utils.exceptions import (NotResponsePayIdError, NotPayOrderError, ProcessedPayOrderError, SignError,
                                     MoneyNotMatchError, NotChannelError, AppIdNotMatchError, MoneyMoreMaxError)
from common.utils.ip_address import check_valid_ip_address

_LOGGER = track_logging.getLogger(__name__)


class BasePay(object):

    def __init__(self):
        pass

    @staticmethod
    def get_api_key(app_conf, mch_id, key_name='API_KEY'):
        """
        获取渠道给的密钥
        :param app_conf: 渠道给的商户信息
        :param mch_id: 商户号
        :param key_name: 密钥对应的变量名
        :return: 渠道密钥
        """
        return app_conf[mch_id][key_name]

    @staticmethod
    def get_gateway(app_conf, gateway='gateway'):
        """
        获取渠道下单地址
        :param app_conf: 渠道给的商户信息
        :param gateway: 下单地址的变量名
        :return: 渠道下单地址
        """
        return app_conf[gateway]

    @staticmethod
    def get_query_geteway(app_conf, query_gateway='query_gateway'):
        """
        获取渠道订单查询地址
        :param app_conf: 渠道给的商户信息
        :param query_gateway: 订单查询地址的变量名
        :return: 渠道订单查询地址
        """
        return app_conf[query_gateway]

    @staticmethod
    def check_channel_order(pay_id, money, appid):
        """
        检查回调/查询返回的信息。
        :param pay_id: 支付订单id
        :param money: 实际交易金额
        :param appid: 商户号
        """
        _LOGGER.info('check_channel_order : %s, %s, %s', pay_id, money, appid)
        pay = get_pay(pay_id)
        if not pay:
            raise NotPayOrderError('check_channel_order pay_id: %s invalid' % pay_id)
        if money > pay.total_fee + 2:
            raise MoneyNotMatchError(
                'check_channel_order money: %s invalid, total_fee: %s' % (money, pay.total_fee))

        channel = get_channel(pay.channel_id)
        if not channel:
            raise NotChannelError('check_channel_order channel: %s invalid' % pay.channel_id)
        channel_info = json.loads(channel.info)
        if appid != channel_info['app_id']:
            raise AppIdNotMatchError(
                'check_channel_order appid: %s invalid, channel_appid' % (appid, channel_info['app_id']))
        # if money < channel_info.get('min_amount', 1) - 1 or money > channel_info.get('max_amount', 60000):
        if money > channel_info.get('max_amount', 60000) + 2:
            raise MoneyMoreMaxError('check_channel_order money: %s invalid' % money)

    @staticmethod
    def check_ip_pay_status(request, pay_id, channel_name, data):
        """
        检查回调IP是否在白名单，回调是否返回pay_id，判断支付状态(已支付成功直接返回，其他已操作的状态抛出异常)
        :param request:
        :param pay_id: 支付订单id
        :param channel_name: 渠道名
        :param data: 渠道异步回调的数据
        :return:
            订单已经处理并且支付成功，直接返回
            pay: 无任何异常并且未订单未处理，返回订单对象
        """
        if not pay_id:
            _LOGGER.error("%s notify fatal error, pay object not exists, data: %s" % (channel_name, data))
            raise NotResponsePayIdError('event does not contain valid pay ID')
        check_valid_ip_address(str(request.META['REMOTE_ADDR']), pay_id)
        pay = get_pay(pay_id)
        if not pay:
            raise NotPayOrderError('pay_id: %s invalid' % pay_id)
        if pay.status != PAY_STATUS.READY:
            if pay.status == PAY_STATUS.SUCC:
                _LOGGER.info('%s %s has been processed and pay succeeded' % (channel_name, pay_id))
                return
            raise ProcessedPayOrderError('pay %s has been processed and not pay success' % pay_id)
        return pay

    @staticmethod
    def get_pay_type(service, pay_type_dict):
        """
        justpay与渠道定义支付方式的转化
        justpay对应支付方式的命名规则:
            支付宝：alipay  支付宝h5/wap：alipay_h5  支付宝原生：alipay_original
            微信：wechat  微信h5/wap：wechat_h5  微信原生：wechat_original  微信小程序：wxapp
            云闪付/银联扫码：cloud_flash  银联快捷：union_quick
            qq：qq  qq h5/wap：qq_h5
            京东：jd  京东h5/wap：jd_h5

        :param service: justpay定义的支付方式
        :param pay_type_dict: justpay对应渠道支付方式的字典(key对应jusypay,value对应渠道方)
        :return: 渠道定义的支付方式
        """
        if service == 'alipay' and pay_type_dict.get('alipay'):  # 支付宝
            return pay_type_dict.get('alipay')
        elif service == 'alipay_h5' and pay_type_dict.get('alipay_h5'):  # 支付宝h5/wap
            return pay_type_dict.get('alipay_h5')
        elif service == 'alipay_original' and pay_type_dict.get('alipay_original'):  # 支付宝原生
            return pay_type_dict.get('alipay_original')
        elif service == 'alipay_original_scan' and pay_type_dict.get('alipay_original_scan'):  # 支付宝原生扫码
            return pay_type_dict.get('alipay_original_scan')
        elif service == 'wechat' and pay_type_dict.get('wechat'):  # 微信
            return pay_type_dict.get('wechat')
        elif service == 'wechat_h5' and pay_type_dict.get('wechat_h5'):  # 微信h5/wap
            return pay_type_dict.get('wechat_h5')
        elif service == 'wechat_original' and pay_type_dict.get('wechat_original'):  # 微信原生
            return pay_type_dict.get('wechat_original')
        elif service == 'wechat_original_scan' and pay_type_dict.get('wechat_original_scan'):  # 微信原生扫码
            return pay_type_dict.get('wechat_original_scan')
        elif service == 'wxapp' and pay_type_dict.get('wxapp'):  # 微信小程序
            return pay_type_dict.get('wxapp')
        elif service == 'cloud_flash' and pay_type_dict.get('cloud_flash'):  # 云闪付(银联扫码)
            return pay_type_dict.get('cloud_flash')
        elif service == 'union_quick' and pay_type_dict.get('union_quick'):  # 银联快捷
            return pay_type_dict.get('union_quick')
        elif service == 'qq' and pay_type_dict.get('qq'):  # qq
            return pay_type_dict.get('qq')
        elif service == 'qq_h5' and pay_type_dict.get('qq_h5'):  # qq h5/wap
            return pay_type_dict.get('qq_h5')
        elif service == 'jd' and pay_type_dict.get('jd'):  # 京东
            return pay_type_dict.get('jd')
        elif service == 'jd_h5' and pay_type_dict.get('jd_h5'):  # 京东h5/wap
            return pay_type_dict.get('jd_h5')
        else:
            return pay_type_dict.get('alipay', 'alipay')

    @staticmethod
    def html_build_form(params, gateway, method='post'):
        """
        请求页面直接跳转方式
        :param params: 请求到渠道的信息
        :param gateway: 请求渠道的地址
        :param method: 请求方式get/post
        :return: html格式的字符串
        """
        html = u"<head><title>loading...</title></head><form id='submit' name='submit' action='" + \
               gateway + "' method='" + method + "'>"

        for k, v in params.items():
            html += "<input type='text' hidden name='%s' value='%s'/>" % (k, v)
        html += "</form>"
        html += "<script>doc" \
                "ument.forms['submit'].submit();</script>"
        return html

    def get_url(self, params, gateway, pay, method='post'):
        """
        获取通过页面调整的url
        :param params: 请求到渠道的信息
        :param gateway: 请求渠道的地址
        :param pay: 订单
        :param method: 请求方式get/post
        :return: url
        """
        html = self.html_build_form(params, gateway, method)
        cache_id = redis_cache.save_html(pay.id, html)
        url = settings.PAY_CACHE_URL + cache_id
        return url

    @staticmethod
    def get_device_ip(info):
        """
        支付用户的ip
        :param info: 通道配置数据
        :return: 支付用户的ip
        """
        try:
            extra = json.loads(info['extra'])
        except:
            extra = {}
        user_info = extra.get('user_info', {})
        return user_info.get('device_ip') or '127.0.0.1'


class SignMd5(object):
    def __init__(self):
        pass

    @staticmethod
    def pending_str(parameter, key, key_name='key', allow='no', not_sign=[]):
        """
        常用md5签名。待签名格式，按照字典序(ascii码从小到大排序)
        :param parameter: 字典数据
        :param key: 商户密钥
        :param key_name: 待签名字符串的商户密钥变量名，每个渠道可能会有不同的定义
        :param allow: 是否允许空值加入到待签名字符串中，默认no不允许,yes允许
        :param not_sign:不参与签名的keys
        :return: pending_str: 待签名字符串
        """
        pending_str = ''
        for k, v in sorted(parameter.items()):
            if allow == 'no' and not v:
                continue
            if k in not_sign:
                continue
            pending_str += '%s=%s&' % (k, v)
        pending_str += '%s=%s' % (key_name, key)
        return pending_str

    @staticmethod
    def md5_sign(sign_str):
        """
        md5签名
        :param sign_str: 待签名字符串
        :return: md5签名后的字符串
        """
        m = hashlib.md5()
        m.update(sign_str.encode('utf-8'))
        return m.hexdigest()

    @staticmethod
    def md5_sign_upper(sign_str):
        """
        md5签名，转大写
        :param sign_str: 待签名字符串
        :return: md5签名后转大写
        """
        m = hashlib.md5()
        m.update(sign_str.encode('utf-8'))
        return m.hexdigest().upper()


class SignRSABase64(object):
    """
    rsa加密，解密，加签，验签，适合base64编码
    """

    def __init__(self):
        pass

    @staticmethod
    def cipher_hash(data, sign_hash):
        """
        hash签名，先接入常用的SHA和MD5，后续有其他方式可新增
        :param data: 加签数据
        :param sign_hash: hash方式
        :return:
        """
        if sign_hash == 'MD5':
            cipher = MD5.new(data)
        else:
            cipher = SHA.new(data)
        return cipher

    def rsa_sign(self, data, pri_key, sign_hash='SHA'):
        """
        私钥加签，最后签名base64编码
        :param data: 回调的数据
        :param pri_key: 私钥
        :param sign_hash: hash类型
        :return: sign:base64签名
        """
        key_bytes = base64.b64decode(pri_key)
        pri = RSA.importKey(key_bytes)
        cipher = self.cipher_hash(data, sign_hash)

        sign = base64.b64encode(sign_pkcs1_v1_5.new(pri).sign(cipher))
        return sign

    def rsa_verify(self, sign, data, pub_key, sign_hash):
        """
        公钥验签
        :param sign: 回调的签名，经过了base64编码
        :param data: 回调的数据
        :param pub_key: 公钥
        :param sign_hash: hash类型
        :raise:
         SignError: 验签未通过
        """
        key_bytes = base64.b64decode(pub_key)
        pub = RSA.importKey(key_bytes)
        cipher = self.cipher_hash(data, sign_hash)
        t = sign_pkcs1_v1_5.new(pub).verify(cipher, base64.b64decode(sign))
        if not t:
            _LOGGER.info("jiefupay sign not pass")
            raise SignError('jiefupay sign not pass')

    @staticmethod
    def rsa_encrypt(byte_size, biz_content, public_key):
        """
        公钥加密，单次加密长度有限制，超过需要多次加密
        :param byte_size: 密钥证书字节
        :param biz_content: 加密数据
        :param public_key: 公钥
        :return: base64编码的加密字符串
        """
        key_bytes = base64.b64decode(public_key)
        _p = RSA.importKey(key_bytes)
        biz_content = biz_content.encode('utf-8')
        default_encrypt_length = byte_size / 8 - 11
        len_content = len(biz_content)
        if len_content < default_encrypt_length:
            return base64.b64encode(PKCS1_v1_5.new(_p).encrypt(biz_content))
        offset = 0
        params_lst = []
        while len_content - offset > 0:
            if len_content - offset > default_encrypt_length:
                params_lst.append(PKCS1_v1_5.new(_p).encrypt(biz_content[offset:offset + default_encrypt_length]))
            else:
                params_lst.append(PKCS1_v1_5.new(_p).encrypt(biz_content[offset:]))
            offset += default_encrypt_length
        target = ''.join(params_lst)
        return base64.b64encode(target)

    @staticmethod
    def rsa_decrypt(byte_size, biz_content, private_key):
        """
        私钥解密，解密长度有限制，超过需要多次解密
        :param byte_size: 密钥证书字节
        :param biz_content: 回调的加密数据，经过了base64编码
        :param private_key: 私钥
        :return: 解密后的数据
        """
        key_bytes = base64.b64decode(private_key)
        _pri = RSA.importKey(key_bytes)
        biz_content = base64.b64decode(biz_content.encode('utf-8'))
        default_length = byte_size / 8
        len_content = len(biz_content)
        if len_content < default_length:
            return PKCS1_v1_5.new(_pri).decrypt(biz_content, 'ERROR')
        offset = 0
        params_lst = []
        while len_content - offset > 0:
            if len_content - offset > default_length:
                params_lst.append(
                    PKCS1_v1_5.new(_pri).decrypt(biz_content[offset: offset + default_length], 'ERROR'))
            else:
                params_lst.append(PKCS1_v1_5.new(_pri).decrypt(biz_content[offset:], 'ERROR'))
            offset += default_length
        target = ''.join(params_lst)
        return target
